/*
 * Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cstdint>
#include <optional>
#include <string>

namespace tensorrt_llm
{
namespace common
{

class QuantMode
{
    // [WARNING] KEEP BELOW DEFINITION IN SYNC WITH tensorrt_llm/quantization/mode.py
public:
    using BaseType = std::uint32_t;

    explicit constexpr QuantMode(BaseType value) noexcept
        : mValue{value}
    {
    }

    QuantMode() noexcept = default;

    constexpr QuantMode(QuantMode const&) noexcept = default;

    constexpr QuantMode& operator=(QuantMode const& other) noexcept = default;

    static constexpr QuantMode none() noexcept
    {
        return QuantMode(BaseType(0));
    }

    static constexpr QuantMode int4Weights() noexcept
    {
        return QuantMode(BaseType(1u) << 0);
    }

    static constexpr QuantMode int8Weights() noexcept
    {
        return QuantMode(BaseType(1u) << 1);
    }

    static constexpr QuantMode activations() noexcept
    {
        return QuantMode(BaseType(1u) << 2);
    }

    static constexpr QuantMode perChannelScaling() noexcept
    {
        return QuantMode(BaseType(1u) << 3);
    }

    static constexpr QuantMode perTokenScaling() noexcept
    {
        return QuantMode(BaseType(1u) << 4);
    }

    static constexpr QuantMode perGroupScaling() noexcept
    {
        return QuantMode(BaseType(1u) << 5);
    }

    static constexpr QuantMode int8KvCache() noexcept
    {
        return QuantMode(BaseType(1u) << 6);
    }

    static constexpr QuantMode fp8KvCache() noexcept
    {
        return QuantMode(BaseType(1u) << 7);
    }

    static constexpr QuantMode fp8Qdq() noexcept
    {
        return QuantMode(BaseType(1u) << 8);
    }

    static constexpr QuantMode fp8RowWise() noexcept
    {
        return QuantMode(BaseType(1u) << 3 | BaseType(1u) << 4 | BaseType(1u) << 9);
    }

    static constexpr QuantMode fp8BlockScales() noexcept
    {
        return QuantMode(BaseType(1u) << 10);
    }

    static constexpr QuantMode w4a8QServe() noexcept
    {
        return QuantMode(BaseType(1u) << 11);
    }

    static constexpr QuantMode nvfp4() noexcept
    {
        return QuantMode(BaseType(1u) << 12);
    }

    static constexpr QuantMode fp4KvCache() noexcept
    {
        return QuantMode(BaseType(1u) << 13);
    }

    static constexpr QuantMode w4a8Mxfp4Fp8() noexcept
    {
        return QuantMode(BaseType(1u) << 14);
    }

    static constexpr QuantMode w4a8Mxfp4Mxfp8() noexcept
    {
        return QuantMode(BaseType(1u) << 15);
    }

    static constexpr QuantMode w4a16Mxfp4() noexcept
    {
        return QuantMode(BaseType(1u) << 16);
    }

    constexpr BaseType value() const noexcept
    {
        return mValue;
    }

    constexpr bool isSet(QuantMode const& mode) const noexcept
    {
        return (mValue & mode.value()) == mode.value();
    }

    constexpr bool hasInt4Weights() const noexcept
    {
        return isSet(int4Weights());
    }

    constexpr bool hasInt8Weights() const noexcept
    {
        return isSet(int8Weights());
    }

    constexpr bool hasActivations() const noexcept
    {
        return isSet(activations());
    }

    constexpr bool hasPerChannelScaling() const noexcept
    {
        return isSet(perChannelScaling());
    }

    constexpr bool hasPerTokenScaling() const noexcept
    {
        return isSet(perTokenScaling());
    }

    constexpr bool hasPerGroupScaling() const noexcept
    {
        return isSet(perGroupScaling());
    }

    constexpr bool hasStaticActivationScaling() const noexcept
    {
        return !hasPerTokenScaling();
    }

    constexpr bool hasInt8KvCache() const noexcept
    {
        return isSet(int8KvCache());
    }

    constexpr bool hasFp8KvCache() const noexcept
    {
        return isSet(fp8KvCache());
    }

    constexpr bool hasFp4KvCache() const noexcept
    {
        return isSet(fp4KvCache());
    }

    constexpr bool hasFp8Qdq() const noexcept
    {
        return isSet(fp8Qdq());
    }

    constexpr bool hasFp8RowWise() const noexcept
    {
        return isSet(fp8RowWise());
    }

    constexpr bool hasNvfp4() const noexcept
    {
        return isSet(nvfp4());
    }

    constexpr bool hasW4a8Mxfp4Fp8() const noexcept
    {
        return isSet(w4a8Mxfp4Fp8());
    }

    constexpr bool hasW4a8Mxfp4Mxfp8() const noexcept
    {
        return isSet(w4a8Mxfp4Mxfp8());
    }

    constexpr bool hasW4a16Mxfp4() const noexcept
    {
        return isSet(w4a16Mxfp4());
    }

    constexpr bool hasKvCacheQuant() const noexcept
    {
        return hasInt8KvCache() || hasFp8KvCache() || hasFp4KvCache();
    }

    static constexpr QuantMode fromDescription(bool quantizeWeights, bool quantizeActivations, bool perToken,
        bool perChannel, bool perGroup, bool useInt4Weights, bool useInt8KvCache, bool useFp8KvCache, bool useFp8Qdq,
        bool useFp8RowWise, bool useW4a8QServe, bool useFp4Quant, bool useFp8BlockScales, bool useW4a8Mxfp4Fp8,
        bool useW4a8Mxfp4Mxfp8, bool useW4a16Mxfp4)
    {
        QuantMode quantMode{};
        if (quantizeWeights)
        {
            if (useInt4Weights)
                quantMode += int4Weights();
            else
                quantMode += int8Weights();
        }

        if (quantizeActivations)
        {
            quantMode += activations();
        }

        if (perChannel)
        {
            quantMode += QuantMode::perChannelScaling();
        }
        if (perToken)
        {
            quantMode += QuantMode::perTokenScaling();
        }
        if (perGroup)
        {
            quantMode += QuantMode::perGroupScaling();
        }

        if (useInt8KvCache)
        {
            quantMode += int8KvCache();
        }

        if (useFp8KvCache)
        {
            quantMode += fp8KvCache();
        }

        if (useFp8Qdq)
        {
            quantMode += fp8Qdq();
        }

        if (useFp8RowWise)
        {
            quantMode += fp8RowWise();
        }

        if (useFp8BlockScales)
        {
            quantMode += fp8BlockScales();
        }

        if (useW4a8QServe)
        {
            quantMode += w4a8QServe();
        }

        if (useFp4Quant)
        {
            quantMode += nvfp4();
        }

        if (useW4a8Mxfp4Fp8)
        {
            quantMode += w4a8Mxfp4Fp8();
        }

        if (useW4a8Mxfp4Mxfp8)
        {
            quantMode += w4a8Mxfp4Mxfp8();
        }

        if (useW4a16Mxfp4)
        {
            quantMode += w4a16Mxfp4();
        }

        return quantMode;
    }

    static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false)
    {
        return fromDescription(true, true, perToken, perChannel, false, false, false, false, false, false, false, false,
            false, false, false, false);
    }

    static constexpr QuantMode useQServe(bool perGroup)
    {
        return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true, false, false,
            false, false, false);
    }

    static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false)
    {
        return fromDescription(true, false, false, false, perGroup, useInt4Weights, false, false, false, false, false,
            false, false, false, false, false);
    }

    static QuantMode const fromQuantAlgo(
        std::optional<std::string> quantAlgo = std::nullopt, std::optional<std::string> kvCacheQuantAlgo = std::nullopt)
    {
        QuantMode quantMode{};
        if (quantAlgo == "W8A16")
        {
            quantMode = useWeightOnly(false, false);
        }
        else if (quantAlgo == "W4A16")
        {
            quantMode = useWeightOnly(true, false);
        }
        else if (quantAlgo == "W4A16_AWQ")
        {
            quantMode = useWeightOnly(true, true);
        }
        else if (quantAlgo == "W4A8_AWQ")
        {
            quantMode = useWeightOnly(true, true);
        }
        else if (quantAlgo == "W4A8_QSERVE_PER_GROUP")
        {
            quantMode = useQServe(false);
        }
        else if (quantAlgo == "W4A8_QSERVE_PER_CHANNEL")
        {
            quantMode = useQServe(true);
        }
        else if (quantAlgo == "W4A16_GPTQ")
        {
            quantMode = useWeightOnly(true, true);
        }
        else if (quantAlgo == "W8A8_SQ_PER_CHANNEL")
        {
            quantMode = useSmoothQuant(false, true);
        }
        else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PLUGIN")
        {
            quantMode = useSmoothQuant(false, false);
        }
        else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN")
        {
            quantMode = useSmoothQuant(true, true);
        }
        else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN")
        {
            quantMode = useSmoothQuant(false, true);
        }
        else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN")
        {
            quantMode = useSmoothQuant(true, false);
        }
        else if (quantAlgo == "FP8")
        {
            quantMode = fromDescription(false, false, false, false, false, false, false, false, true, false, false,
                false, false, false, false, false);
        }
        else if (quantAlgo == "FP8_ROWWISE")
        {
            quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true, false, false,
                false, false, false, false);
        }
        else if (quantAlgo == "FP4")
        {
            quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
                true, false, false, false, false);
        }
        else if (quantAlgo == "FP8_BLOCK_SCALES")
        {
            quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
                false, true, false, false, false);
        }
        else if (quantAlgo == "W4A8_MXFP4_FP8")
        {
            quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
                false, false, true, false, false);
        }
        else if (quantAlgo == "W4A8_MXFP4_MXFP8")
        {
            quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
                false, false, false, true, false);
        }
        else if (quantAlgo == "W4A16_MXFP4")
        {
            quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
                false, false, false, false, true);
        }

        if (kvCacheQuantAlgo == "INT8")
        {
            quantMode += int8KvCache();
        }
        else if (kvCacheQuantAlgo == "FP8")
        {
            quantMode += fp8KvCache();
        }
        else if (kvCacheQuantAlgo == "NVFP4")
        {
            quantMode += fp4KvCache();
        }

        return quantMode;
    }

    constexpr QuantMode operator+(QuantMode const& other) const noexcept
    {
        return QuantMode(mValue | other.mValue);
    }

    constexpr QuantMode& operator+=(QuantMode const& other) noexcept
    {
        return *this = *this + other;
    }

    constexpr QuantMode operator-(QuantMode const& other) const noexcept
    {
        return QuantMode(mValue & ~other.mValue);
    }

    constexpr QuantMode& operator-=(QuantMode const& other) noexcept
    {
        return *this = *this - other;
    }

    constexpr bool operator==(QuantMode const& other) const noexcept
    {
        return mValue == other.mValue;
    }

    constexpr bool operator!=(QuantMode const& other) const noexcept
    {
        return !(*this == other);
    }

private:
    BaseType mValue{0};
};

class GroupwiseQuantAlgo
{
    // Keep align with tensorrt_llm/quantization/mode.py
public:
    static constexpr int BIAS = int(1) << 0;
    static constexpr int ZERO = int(1) << 1;
    static constexpr int PRE_QUANT_SCALE = int(1) << 2;
    static constexpr int FP8_ALPHA = int(1) << 3;
    static constexpr int INT8_WEIGHT = int(1) << 4;
};

} // namespace common
} // namespace tensorrt_llm
